Modular Diffusers Guiders #11311
base: modular-refactor
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@vladmandic's suggestion about having a universal start/stop parameter from here is now implemented too. Note, however, that the guiders should already support any kind of dynamic schedule, with multiple enabling/disabling per inference, if the user modifies the properties on the guider object (see this comment for an example). Batched inference is still supported too! (in terms of multiple prompts and setting
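The dynamic start/stop schedule mentioned above can be pictured with a small sketch. Everything here is illustrative, not the diffusers API: `ToyGuider`, `toggle_guidance`, and the `_enabled` flag are hypothetical names. A per-step callback mutates properties on the guider object so guidance is enabled only inside a step-fraction window.

```python
# Hypothetical sketch of a dynamic guidance schedule: a per-step callback flips
# guidance on/off by mutating properties on the guider object. All names here
# (ToyGuider, toggle_guidance, _enabled) are illustrative, not the diffusers API.
class ToyGuider:
    def __init__(self, start=0.0, stop=1.0):
        # start/stop are fractions of the total number of inference steps
        self.start, self.stop = start, stop
        self._enabled = True

def toggle_guidance(guider, step, num_inference_steps):
    frac = step / num_inference_steps
    # enable guidance only inside the [start, stop) window
    guider._enabled = guider.start <= frac < guider.stop
```

A schedule with multiple windows would just be a different condition inside the callback; the point is that the guider exposes mutable state rather than a fixed schedule.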
@@ -0,0 +1,271 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
For more context on why we need this, see #10875 and this comment.
I discussed with Dhruv and for now we should keep it. After one of the FBC or Guider PRs is merged to main, I can do the refactor and make use of decorators. This will save me the burden of implementing the same thing in both PRs and maintaining it until one gets merged, but rest assured I'll do the refactor before the next release.
def _register_attention_processors_metadata():
    # AttnProcessor2_0
    AttentionProcessorRegistry.register(
For now, only this and BasicTransformerBlock are relevant, since modular diffusers only supports SDXL. The rest is copied over, but we keep it to avoid merge conflicts since the FirstBlockCache PR will most likely be merged before modular diffusers.
return noise_cfg

def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
I think it is easier to work with if we:
- provide a default method here in guider_utils.py to deal with a list of inputs like you specified here: each element could be a tensor or a tuple/list of tensors - this logic should be mostly the same for different guiders, no?
- let each specific guider class define how to prepare each input element
Basically the method here becomes something like this - would this make sense?
def prepare_inputs(self, denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> Tuple[List[torch.Tensor], ...]:
    """
    Prepares the inputs for the denoiser by processing each argument individually using a helper method.
    """
    list_of_inputs = []
    for arg in args:
        if isinstance(arg, (tuple, list)):
            if len(arg) != 2:
                raise ValueError("...")
        elif not isinstance(arg, torch.Tensor):
            raise ValueError("...")
        processed_input = self.prepare_input_single(arg, num_conditions)
        list_of_inputs.append(processed_input)
    return tuple(list_of_inputs)
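For illustration, a concrete guider's `prepare_input_single` helper could look like the sketch below. This is a hypothetical example, not the PR's actual implementation: plain Python values stand in for torch tensors, and the duplication rule (reuse a single input for every condition batch, slice a cond/uncond pair down to the needed conditions) is an assumption about how a CFG-style guider might behave.

```python
# Hypothetical sketch of a concrete guider's per-input preparation. Plain Python
# lists stand in for torch tensors; the real code would operate on tensors.
def prepare_input_single(arg, num_conditions):
    # A (cond, uncond)-style pair: keep only the first `num_conditions` entries,
    # e.g. when guidance is disabled only the conditional branch is used.
    if isinstance(arg, (tuple, list)):
        return list(arg[:num_conditions])
    # A single input (e.g. latents) is reused for every condition batch.
    return [arg] * num_conditions
```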
Sounds good, I'll update the implementations
    added_cond_kwargs=data.added_cond_kwargs,
    return_dict=False,
)[0]
data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
is it possible to do something like this?
noise_pred_outputs = []
for batch_index, (...) in enumerate(zip(...)):
    latents_i = ...
    noise_pred = pipeline.unet(...)
    noise_pred_outputs = self.guider.prepare_and_add_output(pipeline.unet, noise_pred, noise_pred_outputs)
Hey, there were a few more changes related to guiders. Basically, they also need information like sigmas (see the above explanation for CFG++ if we go forward with implementing it that way), latent height/width (for methods like SEG/SAG), tensor formats (SAG), extra prompt information (methods like Attend-and-Excite), and probably more.
I haven't added SAG and A&E because it would be complicated to review with all the required changes. Since we're aiming for modularity, though, the design should allow for such use cases. I'm not quite sure how to proceed yet, but please take another look and LMK what you think.
if self._is_cfgpp_enabled():
    # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later!
    pred_cond = self._preds["pred_cond"]
    pred_uncond = self._preds["pred_uncond"]
    diff = pred_uncond - pred_cond
    pred = pred + diff * self.guidance_scale * self._sigma_next
The original repository implements CFG++ in a different way. I wanted to try and make it work without really modifying all our schedulers, and so it's done this way. The math works out the same.
For context, in our schedulers, we do:
new_sample = sample + model_output_after_cfg * (sigmas[i + 1] - sigmas[i])
which expands to:
new_sample = sample - model_output_after_cfg * sigmas[i] + model_output_after_cfg * sigmas[i + 1]
What we need to do for CFG++ is this instead:
new_sample = sample - model_output_after_cfg * sigmas[i] + model_output_uncond * sigmas[i + 1]
(This is only for EulerDiscreteScheduler and will differ for other schedulers)
After a little bit of working it out on paper, I found that some different schedulers don't really have to be modified if we add and subtract some terms after the scheduler step. We will need to have some specialized code (it can either exist in this file or the scheduler file) to add/subtract the right terms for each scheduler, so LMK how you think we should do it
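The algebra above can be checked numerically. The sketch below uses plain floats standing in for tensors and verifies that applying the correction term `(pred_uncond - pred_cond) * guidance_scale * sigma_next` after a standard Euler step reproduces the CFG++ update, matching the guider code shown earlier.

```python
# Numeric check of the CFG++ algebra (plain floats standing in for tensors).
# Assumes the Euler-style update: new_sample = sample + model_output * (sigma_next - sigma).
guidance_scale = 7.5
sigma, sigma_next = 1.0, 0.8
sample, pred_cond, pred_uncond = 0.5, 0.3, 0.1

# Standard CFG combine
model_output = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

# CFG++ target: the sigma_next term uses the unconditional prediction
target = sample - model_output * sigma + pred_uncond * sigma_next

# Standard Euler step, then the correction term applied outside the scheduler,
# matching `pred = pred + (pred_uncond - pred_cond) * guidance_scale * sigma_next`
standard = sample + model_output * (sigma_next - sigma)
corrected = standard + (pred_uncond - pred_cond) * guidance_scale * sigma_next

assert abs(corrected - target) < 1e-9
```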
Nevermind, it's better to just do this in the scheduler
Also cc @DN6 for all the custom hook implementations |
@@ -668,7 +675,38 @@ def step(
    dt = self.sigmas[self.step_index + 1] - sigma_hat

    prev_sample = sample + derivative * dt

    if _use_cfgpp:
So we are hoping to find a scalable solution that can provide maximum support for community creativity. It isn't scalable if it requires code changes to the schedulers.
I think it can be manipulated inside the guider, no? Since we have all the variables in pipeline state and all the components in model states, you can use them to access the scheduler and the sigmas counter.
For Euler, yes, it is easy to add a correction term outside the scheduler step and make it work -- this is how it was originally implemented.
For DDIM, DPM++, and all the others, it quickly gets very complicated to handle all the correction terms correctly, since you need to recalculate a lot of variables for the original model_output, subtract them out, calculate the correct variables using model_pred_uncond, and add those in. I don't think that having specialized code in the guider to handle all usable schedulers, probably using isinstance checks, is a good approach.
Yeah, agree it should not be in guiders.
There is a dependency between guider & scheduler; our scheduler implementations are not aware of the guidance approach, since they were all designed to work with CFG.
We can revisit this last, but I would lean more towards making new schedulers for CFG++, since it basically requires a new step function.
I think we should try to see if there's a way to break down existing schedulers into smaller methods, and consider the possibility of overriding the behaviour given certain params from the user. Fully re-implementing each CFG++-supported scheduler will probably just become combinatorial explosion hell. We also need to consider new techniques that come up, which might require tweaking just small aspects of the scheduler, and we should be able to make the experience of such integrations better/easier ("Assemble like Lego" for modular diffusers).
I agree.
Not sure about breaking existing schedulers into smaller methods, though; they are already pretty small. I think allowing overriding should be sufficient :) We currently already sort of allow overriding set_timesteps by accepting custom timesteps created by the users, but it is a bit hacky/not very nice.
We should find a way to support different set_timesteps & step methods very easily (maybe something similar to attention processors, but much simpler).
How about we only support one scheduler for CFG++ in this PR, and do a refactor on schedulers as a follow-up?
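As a discussion aid, the "allow overriding" idea could look something like the toy sketch below. `ToyEulerScheduler` and the pluggable `update_rule` are hypothetical names, not the diffusers API: the point is that only the final update rule is swapped for CFG++, rather than re-implementing the whole scheduler.

```python
# Hypothetical sketch (not the diffusers API): a scheduler whose final update
# rule is pluggable, so a CFG++ variant only swaps the combine step.
class ToyEulerScheduler:
    def __init__(self, sigmas, update_rule=None):
        self.sigmas = sigmas
        # default rule: standard Euler update with the post-CFG model output
        self.update_rule = update_rule or (
            lambda sample, model_output, model_output_uncond, s, s_next:
                sample + model_output * (s_next - s)
        )

    def step(self, sample, model_output, model_output_uncond, step_index):
        s, s_next = self.sigmas[step_index], self.sigmas[step_index + 1]
        return self.update_rule(sample, model_output, model_output_uncond, s, s_next)

# CFG++ variant: the sigma_next term uses the unconditional prediction instead
cfgpp_rule = (
    lambda sample, model_output, model_output_uncond, s, s_next:
        sample - model_output * s + model_output_uncond * s_next
)
```

Usage would be `ToyEulerScheduler(sigmas, update_rule=cfgpp_rule)`; attention processors follow a similar "swap one pluggable piece" pattern, just with more machinery.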
I left a few more comments, for discussion only at this point. No need to do anything for now; let's find a design we are happy with first.
if self._num_outputs_prepared > self.num_conditions:
    raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
self._preds[key] = pred
Let's try not to store tensors inside the guider class unless we have to; this can go into the guider_data if we decide to make one.
This may be possible. The guider decides what batches of input it processes - this may be a batch that results in either pred_cond, pred_uncond, pred_cond_skip, and so on. The guider will need to maintain this state information (i.e. which batch of data it is currently processing), but the modular pipeline can pull this info and maintain the output dict. If this sounds good, I'll update the implementations.
sounds good! let's try:)
# prepare latents for controlnet using the guider
data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents)
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
Maybe easier to put all the batched inputs into their own data class, something like this:

pipeline.guider.set_input_fields(
    latents="latents",
    prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
    # we can make a data field to indicate whether it is conditional or not
    is_uncond=(False, True),
    ...
)
# this should return a list or tuple
batched_guider_data = pipeline.guider.prepare_inputs(data)
for batch in batched_guider_data:
    # instead of latents_i, we can access and update via batch.latents, which corresponds to guider_data[i].latents
    batch.latents = pipeline.scheduler.scale_model_input(batch.latents, t)
    ...
    added_cond_kwargs = {
        "text_embeds": batch.pooled_prompt_embeds,
        "time_ids": batch.add_time_ids,
    }
    ...
    if batch.is_uncond and data.guess_mode:
        down_block_res_samples = [torch.zeros_like(d) for d in down_block_res_samples]
    else:
        down_block_res_samples, mid_block_res_sample = pipeline.controlnet(batch.latents, ...)
    ...
    # each batch has its own model_output
    batch.noise_pred = pipeline.unet(batch.latents, ...)

# Perform guidance
# I think we can combine guider.prepare_outputs & the guider forward pass
data.noise_pred = pipeline.guider(batched_guider_data, ...)
Regarding "# we can make a data field to indicate it is conditional or not":
Since there has been no concrete use case of using more than one conditional and one unconditional batch of data, I think we could simply enforce the convention that:
- if one value is passed, it can correspond to either conditional or unconditional data. It does not really matter which, because the user controls what they are passing
- if two values are passed, the first value corresponds to conditional data and the second value corresponds to unconditional data.
Do you know of an example where more than one of each cond/uncond is used, in order to need the is_uncond identifier?
Regarding "batched_guider_data = pipeline.guider.prepare_inputs(data)":
This sounds good to me. We provide all the input fields as available in data and pull values out of it when prepare_inputs is called within the guider. This way, the guider has access to data and can pull further information (such as the excite tokens required in Attend-and-Excite).
Regarding "# each batch has its own model_output":
Sounds good to me. Each batch in batched_guider_data will need an associated ID for this to work. For example, for SLG, each batch will need to be marked as either pred_cond, pred_uncond, or pred_cond_skip (but any of these could be dynamically disabled/enabled by changing values with a callback or directly on the guidance object).
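A minimal sketch of the per-batch ID idea, with hypothetical names (`GuiderBatch` and `collect_outputs` are illustrative, not part of the PR): each batch carries its prediction key so the pipeline, not the guider, can assemble the output dict.

```python
from dataclasses import dataclass

# Hypothetical sketch: each guider batch is tagged with an identifier so the
# modular pipeline (not the guider) can maintain the prediction dict. The
# pred_id values follow the SLG naming used in the discussion above.
@dataclass
class GuiderBatch:
    pred_id: str          # e.g. "pred_cond", "pred_uncond", "pred_cond_skip"
    latents: object       # torch.Tensor in the real pipeline
    noise_pred: object = None

def collect_outputs(batches):
    # the pipeline pulls each batch's ID and builds the output dict,
    # so no tensors need to be stored on the guider itself
    return {b.pred_id: b.noise_pred for b in batches}
```

Disabling a branch (e.g. skipping pred_cond_skip for some steps) would then just mean omitting that batch from the list for that step.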
Re: "Since there has been no concrete use case of using more than one conditional and one unconditional batch of data, I think we could simply enforce the convention that:
- if one value is passed, it can correspond to either conditional or unconditional data
- if two values are passed, the first value corresponds to conditional data and the second to unconditional data"
sounds good to me!
Another thing I have in mind, and will try to play around with a little this week, is to create a LoopSequentialPipelineBlocks that's similar to SequentialPipelineBlocks (see class SequentialPipelineBlocks). This way the denoising loop itself will be modular too, e.g. you can just add controlnet/inpaint into your denoise block instead of rewriting a new denoising block that does these things; and this way we do not have to support additional callbacks.
(also, I won't add more commit to the refactor PR until this PR is merged, so don't worry about merge conflicts)
The following methods are currently supported:
Note: PAG is implemented as Skip Layer Guidance and does not have its own guider implementation. The equivalent SLG initialization is:
Note: STG is also implemented as Skip Layer Guidance:
- skip_attention=False, skip_ff=True, skip_attention_scores=False
- skip_attention=True, skip_ff=False, skip_attention_scores=False
- skip_attention=True, skip_ff=True, skip_attention_scores=False
- skip_attention=False, skip_ff=False, skip_attention_scores=True (essentially PAG)
Note: You can use different SLG configs for different parts of the model. Create multiple configs and pass them as a list to skip_layer_config.
Minimal all guiders testing script
YiYi's modified full test script